import os
os.chdir(os.path.dirname(os.path.abspath(__file__)))  # Switch the working directory to this script's folder at import time, so relative paths such as 'best.onnx' resolve next to the script.
import cv2
import numpy as np
import onnxruntime as ort
import time
import random
import serial
def plot_one_box(x, img, color=None, label=None, line_thickness=None):
    """Draw a single bounding box (and optional caption) onto ``img`` in place.

    Args:
        x: Box corners ``(x1, y1, x2, y2)``; each value is truncated with ``int``.
        img: Image array that is drawn on directly.
        color: Drawing colour as three channel values; a random colour is picked
            when this is falsy.
        label: Caption text; when falsy only the box outline is drawn.
        line_thickness: Outline thickness in pixels; when falsy it is derived
            from the image dimensions.
    """
    thick = line_thickness or round(0.002 * (img.shape[0] + img.shape[1]) / 2) + 1
    if not color:
        color = [random.randint(0, 255) for _ in range(3)]

    top_left = (int(x[0]), int(x[1]))
    bottom_right = (int(x[2]), int(x[3]))
    cv2.rectangle(img, top_left, bottom_right, color, thickness=thick, lineType=cv2.LINE_AA)

    if not label:
        return

    font_thick = max(thick - 1, 1)
    font_scale = thick / 3
    (text_w, text_h), _baseline = cv2.getTextSize(label, 0, fontScale=font_scale, thickness=font_thick)

    # Filled caption background sits just above the box's top-left corner.
    caption_corner = (top_left[0] + text_w, top_left[1] - text_h - 3)
    cv2.rectangle(img, top_left, caption_corner, color, -1, cv2.LINE_AA)
    cv2.putText(
        img,
        label,
        (top_left[0], top_left[1] - 2),
        0,
        font_scale,
        [225, 255, 255],
        thickness=font_thick,
        lineType=cv2.LINE_AA,
    )
def _make_grid( nx,ny) :
    xv, yv = np.meshgrid(np. arange(ny) , np.arange(nx))
    return np.stack((xv, yv), 2).reshape((-1,2)).astype(np.float32)
def cal_outputs(outs, nl, na, model_w, model_h, anchor_grid, stride):
    """Decode raw detection-head rows in place into pixel-space box centres/sizes.

    ``outs`` holds the concatenated rows of all ``nl`` detection layers, one layer
    after another; layer ``i`` contributes ``na * ny * nx`` rows, arranged as ``na``
    consecutive blocks (one per anchor) of ``ny * nx`` grid cells. Columns 0-1 are
    turned into box centres and columns 2-3 into box widths/heights, both in
    network-input pixels; the remaining columns are left untouched. The formulas
    are the YOLOv5-style ones and presumably expect already-sigmoided values —
    NOTE(review): confirm against the exported model.

    Args:
        outs: 2-D prediction array, modified in place.
        nl: Number of detection layers.
        na: Anchors per layer.
        model_w, model_h: Network input size in pixels.
        anchor_grid: Per-layer anchor sizes; ``anchor_grid[i]`` has ``na`` rows of (w, h).
        stride: Per-layer downsampling factor.

    Returns:
        The same ``outs`` array, for convenience.
    """
    row_start = 0
    for i in range(nl):
        layer_stride = stride[i]
        ny = int(model_h / layer_stride)  # grid rows
        nx = int(model_w / layer_stride)  # grid columns
        length = int(na * ny * nx)
        rows = slice(row_start, row_start + length)

        # _make_grid(a, b) yields pairs whose first column ranges over [0, b) and
        # varies fastest, so passing (rows, cols) gives (x, y) cell offsets with x
        # varying fastest. The grid is rebuilt per call; there is nothing to cache.
        grid = _make_grid(ny, nx)
        outs[rows, 0:2] = (outs[rows, 0:2] * 2. - 0.5 + np.tile(grid, (na, 1))) * int(layer_stride)
        # Each anchor's (w, h) applies to every cell of its block.
        outs[rows, 2:4] = (outs[rows, 2:4] * 2) ** 2 * np.repeat(anchor_grid[i], ny * nx, axis=0)
        row_start += length
    return outs
def post_process_opencv(outputs, model_h, model_w, img_h, img_w, thred_nms, thred_cond):
    """Rescale decoded predictions to image coordinates and apply NMS.

    Args:
        outputs: 2-D array of rows ``[cx, cy, w, h, obj, cls_0, cls_1, ...]`` in
            network-input pixels (as produced by the decoding step).
        model_h, model_w: Network input size.
        img_h, img_w: Size of the original image the boxes are rescaled to.
        thred_nms: IoU threshold for non-maximum suppression.
        thred_cond: Score threshold; the score used is column 4 only.

    Returns:
        ``(boxes, confs, cls_ids)``: ``boxes`` is an ``(N, 4)`` array of
        ``[x1, y1, x2, y2]`` corners, ``confs`` and ``cls_ids`` are length-``N``
        arrays. Three empty lists are returned when nothing survives.
    """
    conf = outputs[:, 4].tolist()
    c_x = outputs[:, 0] / model_w * img_w
    c_y = outputs[:, 1] / model_h * img_h
    w = outputs[:, 2] / model_w * img_w
    h = outputs[:, 3] / model_h * img_h
    cls_id = np.argmax(outputs[:, 5:], axis=1)

    x1 = c_x - w / 2
    y1 = c_y - h / 2
    # Corner form handed back to callers.
    boxes_xyxy = np.stack((x1, y1, c_x + w / 2, c_y + h / 2), axis=-1).astype(np.float64)
    # cv2.dnn.NMSBoxes expects (x, y, width, height) rectangles, not corners;
    # passing corners makes it compute IoU on the wrong rectangles.
    boxes_xywh = np.stack((x1, y1, w, h), axis=-1).tolist()

    keep = cv2.dnn.NMSBoxes(boxes_xywh, conf, thred_cond, thred_nms)
    # Depending on the OpenCV version this is (N,), (N, 1) or an empty tuple;
    # flatten so indexing below always yields (N, 4) boxes.
    keep = np.asarray(keep, dtype=int).reshape(-1)
    if keep.size == 0:
        return [], [], []
    return boxes_xyxy[keep], np.asarray(conf)[keep], cls_id[keep]

def infer_img(img0, net, model_h, model_w, n1, na, stride, anchor_grid, thred_nms=0.3, thred_cond=0.45):
    """Run the ONNX detector on one frame and return NMS-filtered detections.

    Args:
        img0: Input frame, 3-channel in BGR order (it is converted BGR->RGB below).
        net: ``onnxruntime.InferenceSession`` whose first input receives the blob.
        model_h, model_w: Network input size in pixels.
        n1: Number of detection layers.
        na: Anchors per detection layer.
        stride: Per-layer strides.
        anchor_grid: Per-layer anchor (w, h) pairs.
        thred_nms: IoU threshold for NMS.
        thred_cond: Score threshold.

    Returns:
        ``(boxes, confs, ids)`` as returned by ``post_process_opencv``, with boxes
        in ``img0`` pixel coordinates.
    """
    # cv2.resize takes dsize as (width, height).
    img = cv2.resize(img0, (model_w, model_h), interpolation=cv2.INTER_AREA)
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    img = img.astype(np.float32) / 255.0
    # HWC -> 1xCxHxW, made C-contiguous for the runtime.
    blob = np.ascontiguousarray(np.transpose(img, (2, 0, 1))[np.newaxis])

    # Model inference; drop the leading batch dimension of the first output.
    outs = net.run(None, {net.get_inputs()[0].name: blob})[0].squeeze(axis=0)
    # Decode boxes using the layer count passed in (not a module-level global).
    outs = cal_outputs(outs, n1, na, model_w, model_h, anchor_grid, stride)

    img_h, img_w = img0.shape[:2]
    return post_process_opencv(outs, model_h, model_w, img_h, img_w, thred_nms, thred_cond)

if __name__ == "__main__":
    # Model loading; the relative path resolves against the script directory
    # because of the os.chdir at the top of the file.
    model_pb_path = 'best.onnx'
    so = ort.SessionOptions()
    net = ort.InferenceSession(model_pb_path, so, providers=['CPUExecutionProvider'])
    ser = serial.Serial("/dev/ttyAMA0", 115200)

    # Model class index -> short label.
    dic_labels = {
        0: 'g',
        1: 'r',
        2: 'b',
    }
    # Label -> digit used to encode the left-to-right order of detections.
    label_codes = {'g': 1, 'r': 2, 'b': 3}
    # Left-to-right digit sequence of exactly three detections -> command sent
    # over serial (and printed). Any other three-detection sequence sends 0.
    order_commands = {
        (1, 2, 3): 1,
        (1, 3, 2): 2,
        (2, 1, 3): 3,
        (2, 3, 1): 4,
        (3, 1, 2): 5,
        (3, 2, 1): 6,
        (1, 1, 1): 7,  # three 'g'
        (2, 2, 2): 8,  # three 'r'
        (3, 3, 3): 9,  # three 'b'
    }

    model_h = 320
    model_w = 320
    nl = 3
    na = 3
    stride = [8., 16., 32.]
    anchors = [[10, 13, 16, 30, 33, 23], [30, 61, 62, 45, 59, 119], [116, 90, 156, 198, 373, 326]]
    anchor_grid = np.asarray(anchors, dtype=np.float32).reshape(nl, -1, 2)
    video = 1
    cap = cv2.VideoCapture(video)

    try:
        while True:
            success, img = cap.read()
            if not success:
                break  # Stop when the stream ends or a frame cannot be read.

            start_time = time.time()
            boxes, confs, ids = infer_img(
                img, net, model_h, model_w, nl, na, stride, anchor_grid
            )
            elapsed_time = time.time() - start_time
            fps = 1.0 / elapsed_time if elapsed_time > 0 else 0.0
            cv2.putText(img, f"FPS: {fps:.2f}", (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 255), 2)

            ans = []
            # Visit detections left to right by their x1 coordinate.
            for box, conf, cls_id in sorted(zip(boxes, confs, ids), key=lambda det: det[0][0]):
                x1, y1, x2, y2 = box.astype(int)
                label = dic_labels.get(cls_id, str(cls_id))
                # plot_one_box draws the outline itself; no separate rectangle needed.
                plot_one_box(
                    [x1, y1, x2, y2], img,
                    color=[0, 255, 0],  # green boxes
                    label=f"{label}: {conf:.2f}",
                    line_thickness=2,
                )
                if label in label_codes:
                    ans.append(label_codes[label])

            if len(ans) == 3:
                command = order_commands.get(tuple(ans), 0)
                print(command)
                ser.write(str(command).encode())

            cv2.imshow("Object Detection", img)
            if cv2.waitKey(1) & 0xFF == ord('q'):
                break
    finally:
        # Release the camera, window and serial port even if the loop raises.
        cap.release()
        cv2.destroyAllWindows()
        ser.close()
